{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Character-based RNN language model trained on 'The Complete Works of William Shakespeare'\n",
    "Based on http://karpathy.github.io/2015/05/21/rnn-effectiveness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Load 'The Complete Works of William Shakespeare'\n",
    "using Knet\n",
    "# gutenberg.jl lives in Knet's data directory and presumably defines shakespeare() (it is called right after the include)\n",
    "include(Knet.dir(\"data\",\"gutenberg.jl\"))\n",
    "# trn and tst hold indices into chars (the sample cell below looks characters up via chars[trn[...]])\n",
    "trn,tst,chars = shakespeare()\n",
    "# Show a one-line summary of each of the three objects\n",
    "map(summary,(trn,tst,chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [],
   "source": [
    "# Show a short excerpt of the training text\n",
    "println(join(chars[trn[1020:1210]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters shared by the cells below\n",
    "# RNN flavor handed to rnninit as rnnType\n",
    "RNNTYPE = :lstm\n",
    "# Number of parallel sequences (rows) that mb splits the corpus into\n",
    "BATCHSIZE = 256\n",
    "# Time steps per minibatch (third argument to minibatch in mb)\n",
    "SEQLENGTH = 100\n",
    "# Length of each character embedding column in wx, also the RNN input size\n",
    "INPUTSIZE = 168\n",
    "# Number of embedding columns and of output scores, presumably length(chars) -- verify against the data\n",
    "VOCABSIZE = 84\n",
    "# RNN hidden state size\n",
    "HIDDENSIZE = 334\n",
    "# Number of stacked RNN layers\n",
    "NUMLAYERS = 1\n",
    "# Dropout probability given to rnninit and used as pdrop during training\n",
    "DROPOUT = 0.0\n",
    "# Adam settings: learning rate, beta1, beta2 and eps\n",
    "LR=0.001\n",
    "BETA_1=0.9\n",
    "BETA_2=0.999\n",
    "EPS=1e-08\n",
    "# Number of passes over the training minibatches\n",
    "EPOCHS = 30;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Build (input, target) minibatches for next-character prediction\n",
    "function mb(a)\n",
    "    ncols = div(length(a), BATCHSIZE)                         # columns per row, the leftover tail is dropped\n",
    "    rows = reshape(a[1:ncols*BATCHSIZE], ncols, BATCHSIZE)'   # (B,N) matrix whose rows are contiguous stretches of a\n",
    "    inputs  = rows[:, 1:ncols-1]\n",
    "    targets = rows[:, 2:ncols]                                # same rows shifted one step ahead\n",
    "    return minibatch(inputs, targets, SEQLENGTH)              # split into (B,T) blocks\n",
    "end\n",
    "dtrn = mb(trn)\n",
    "dtst = mb(tst)\n",
    "map(length, (dtrn,dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Create the model parameters: RNN spec and weights, embedding, output projection and bias\n",
    "function initmodel()\n",
    "    rspec, rweights = rnninit(INPUTSIZE, HIDDENSIZE; rnnType=RNNTYPE, numLayers=NUMLAYERS, dropout=DROPOUT)\n",
    "    embedding = KnetArray(xavier(Float32, INPUTSIZE, VOCABSIZE))    # one column per character\n",
    "    outweight = KnetArray(xavier(Float32, VOCABSIZE, HIDDENSIZE))   # hidden state -> character scores\n",
    "    outbias   = KnetArray(zeros(Float32, VOCABSIZE, 1))\n",
    "    return rspec, rweights, embedding, outweight, outbias\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Score the next character for every position of the input batch\n",
    "# Returns the score matrix together with the final hidden and cell states\n",
    "function predict(ws,xs,hx,cx;pdrop=0)\n",
    "    rspec, rweights, embedding, outweight, outbias = ws\n",
    "    emb = dropout(embedding[:,xs], pdrop)                       # xs=(B,T) emb=(X,B,T)\n",
    "    out, hy, cy = rnnforw(rspec, rweights, emb, hx, cx; hy=true, cy=true)   # out=(H,B,T) hy=cy=(H,B,L)\n",
    "    out = dropout(out, pdrop)\n",
    "    flat = reshape(out, size(out,1), size(out,2)*size(out,3))   # flat=(H,B*T)\n",
    "    scores = outweight*flat .+ outbias\n",
    "    return scores, hy, cy\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Negative log likelihood of targets y given inputs x\n",
    "# h is a two-element container with the hidden and cell state, overwritten with the new states\n",
    "function loss(w,x,y,h;o...)\n",
    "    scores, hnext, cnext = predict(w, x, h...; o...)\n",
    "    # Store plain values so the next call does not carry gradient bookkeeping along\n",
    "    h[1] = getval(hnext)\n",
    "    h[2] = getval(cnext)\n",
    "    return nll(scores, y)\n",
    "end\n",
    "\n",
    "# Returns (gradient with respect to w, loss value)\n",
    "lossgradient = gradloss(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Run one pass over data, updating model in place, and return the mean minibatch loss\n",
    "function train(model,data,optim)\n",
    "    state = Any[nothing, nothing]   # hidden and cell state carried from batch to batch\n",
    "    total = 0\n",
    "    nbatches = 0\n",
    "    for (x,y) in data\n",
    "        grads, batchloss = lossgradient(model, x, y, state; pdrop=DROPOUT)\n",
    "        update!(model, grads, optim)\n",
    "        total += batchloss\n",
    "        nbatches += 1\n",
    "    end\n",
    "    return total/nbatches\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Return the mean minibatch loss of model over data without updating any weights\n",
    "function test(model,data)\n",
    "    state = Any[nothing, nothing]   # hidden and cell state carried from batch to batch\n",
    "    total = 0\n",
    "    nbatches = 0\n",
    "    for (x,y) in data\n",
    "        total += loss(model, x, y, state)\n",
    "        nbatches += 1\n",
    "    end\n",
    "    return total/nbatches\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Reuse a saved model when one exists, otherwise train from scratch and save it\n",
    "using JLD\n",
    "model=optim=nothing; knetgc()\n",
    "if isfile(\"shakespeare.jld\")\n",
    "    model = load(\"shakespeare.jld\",\"model\")\n",
    "else\n",
    "    model = initmodel()\n",
    "    optim = optimizers(model, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS)\n",
    "    info(\"Training...\")\n",
    "    @time for epoch in 1:EPOCHS\n",
    "        @time trnloss = train(model,dtrn,optim) # ~18 seconds\n",
    "        @time tstloss = test(model,dtst)        # ~0.5 seconds\n",
    "        println((:epoch, epoch, :trnppl, exp(trnloss), :tstppl, exp(tstloss)))\n",
    "    end\n",
    "    save(\"shakespeare.jld\",\"model\",model)\n",
    "end\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Sample from trained model: print n characters generated one at a time\n",
    "function generate(model,n)\n",
    "    # Draw an index from the softmax distribution given by the scores y\n",
    "    function sample(y)\n",
    "        p,r=Array(exp.(y-logsumexp(y))),rand()\n",
    "        for j=1:length(p); (r -= p[j]) < 0 && return j; end\n",
    "        # Rounding can leave sum(p) slightly below 1, so r may survive the whole loop.\n",
    "        # Without this fallback the function would return nothing and chars[x] would fail.\n",
    "        return length(p)\n",
    "    end\n",
    "    h,c = nothing,nothing                  # no RNN state yet\n",
    "    x = findfirst(chars,'\\n')              # seed with the newline character\n",
    "    for i=1:n\n",
    "        y,h,c = predict(model,[x],h,c)     # feed one character and carry the state forward\n",
    "        x = sample(y)\n",
    "        print(chars[x])\n",
    "    end\n",
    "    println()\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Print 1000 characters sampled from the model\n",
    "generate(model,1000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.4",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
